{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "torch.manual_seed(24)\n",
    "\n",
    "pos = torch.rand(4, 2)\n",
    "rot = torch.rand(4, 1) * torch.pi * 2\n",
    "d = torch.cat([rot.cos(), rot.sin()], dim=-1)\n",
    "x, y = pos.T\n",
    "u, v = d.T\n",
    "\n",
    "plt.quiver(x, y, u, v)\n",
    "plt.xlim(-0.2, 1.2)\n",
    "plt.ylim(-0.2, 1.2)\n",
    "\n",
    "pos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def separation(p0, p1, p1_d):\n",
    "    rel_pos = rel_pos =  p1.unsqueeze(0) - p0.unsqueeze(1)\n",
    "    z_distance = (rel_pos * p1_d).sum(-1, keepdim=True)\n",
    "    z_displacement = z_distance * p1_d\n",
    "\n",
    "    r_displacement = rel_pos - z_displacement\n",
    "    r_distance = torch.norm(r_displacement, dim=-1, keepdim=True)\n",
    "    return z_distance, r_distance\n",
    "    \n",
    "def downwash(p0, p1, p1_d, kr=2, kz=1):\n",
    "    \"\"\"\n",
    "    p0: [n, d]\n",
    "    p1: [m, d]\n",
    "    p1: [m, d]\n",
    "    \"\"\"\n",
    "    z, r = separation(p0, p1, p1_d)\n",
    "    z = torch.clip(z, 0)\n",
    "    v = torch.exp(-0.5 * torch.square(kr * r / z)) / (1 + kz * z)**2\n",
    "    f = v * - p1_d\n",
    "    return f\n",
    "\n",
    "def off_diag(a: torch.Tensor) -> torch.Tensor:\n",
    "    assert a.shape[0] == a.shape[1]\n",
    "    n = a.shape[0]\n",
    "    return a.flatten(0, 1)[1:].unflatten(0, (n-1, n+1))[:, :-1].reshape(n, n-1, *a.shape[2:])\n",
    "\n",
    "f = downwash(pos, pos, d)\n",
    "f.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xx = torch.linspace(0, 1, 20)\n",
    "to = torch.stack(torch.meshgrid(xx, xx), dim=-1).flatten(0, 1)\n",
    "f = downwash(to, pos[[1]], d[[1]], kr=2).squeeze()\n",
    "\n",
    "plt.quiver(*to.T, *f.T)\n",
    "plt.xlim(-0.2, 1.2)\n",
    "plt.ylim(-0.2, 1.2)\n",
    "f.norm(dim=-1).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.linspace(0, 3, 30)\n",
    "plt.plot(x, 1/(x+1))\n",
    "plt.plot(x, 1/(x*x+1))\n",
    "plt.plot(x, 1/(x+1)**2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "ecd0e0e6c69168b243675a580aefa75812739d7c9ce65cb4caa687a3edf0bcc7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
